{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "57eebd05",
      "metadata": {
        "id": "57eebd05"
      },
      "source": [
        "# Synthetic Dataset Generator with Quality Scoring\n",
        "\n",
        "An AI-powered tool that creates realistic synthetic datasets for any business case with flexible schema creation, synonym permutation for diversity, and automated quality scoring.\n",
        "\n",
        "## Features\n",
        "- **Multi-Model Support**: HuggingFace models (primary) + Commercial APIs\n",
        "- **Flexible Schema Creation**: LLM-generated, manual, or hybrid approaches\n",
        "- **Synonym Permutation**: Post-process datasets to increase diversity\n",
        "- **Quality Scoring**: A separate LLM evaluates dataset quality\n",
        "- **GPU Optimized**: Designed for Google Colab T4 GPUs\n",
        "- **Multiple Output Formats**: CSV, TSV, JSON, JSONL\n",
        "\n",
        "## Quick Start\n",
        "1. **Schema Tab**: Define your dataset structure\n",
        "2. **Generation Tab**: Generate synthetic data\n",
        "3. **Permutation Tab**: Add diversity with synonyms\n",
        "4. **Scoring Tab**: Evaluate data quality\n",
        "5. **Export Tab**: Download your dataset\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a1673e5a",
      "metadata": {
        "id": "a1673e5a"
      },
      "outputs": [],
      "source": [
        "# Install dependencies\n",
        "%pip install -q --upgrade bitsandbytes accelerate transformers\n",
        "%pip install -q openai gradio nltk\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "m-yhYlN4OQEC",
      "metadata": {
        "id": "m-yhYlN4OQEC"
      },
      "outputs": [],
      "source": [
        "# Check which GPU (if any) this runtime is attached to.\n",
        "gpu_info = !nvidia-smi\n",
        "gpu_info = '\\n'.join(gpu_info)\n",
        "# nvidia-smi prints a \"...has failed...\" message when the driver is unavailable,\n",
        "# and the shell typically prints \"command not found\" when the binary is absent\n",
        "# (CPU-only runtimes); both mean there is no usable GPU.\n",
        "if 'failed' in gpu_info or 'not found' in gpu_info:\n",
        "    print('Not connected to a GPU')\n",
        "else:\n",
        "    print(gpu_info)\n",
        "    if 'Tesla T4' in gpu_info:\n",
        "        print(\"Success - Connected to a T4\")\n",
        "    else:\n",
        "        print(\"NOT CONNECTED TO A T4\")"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Setup: imports, API keys, and model configuration"
      ],
      "metadata": {
        "id": "jokJ6H7o5qaF"
      },
      "id": "jokJ6H7o5qaF"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5ab3109c",
      "metadata": {
        "id": "5ab3109c"
      },
      "outputs": [],
      "source": [
        "# Imports and Setup\n",
        "import os\n",
        "import io\n",
        "import time\n",
        "import json\n",
        "import pandas as pd\n",
        "import random\n",
        "import re\n",
        "import gc\n",
        "import torch\n",
        "from typing import List, Dict, Any, Tuple\n",
        "import warnings\n",
        "warnings.filterwarnings(\"ignore\")\n",
        "\n",
        "# Google Colab\n",
        "from google.colab import files\n",
        "\n",
        "# LLM APIs\n",
        "from openai import OpenAI\n",
        "\n",
        "# HuggingFace\n",
        "from huggingface_hub import login\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig\n",
        "\n",
        "# Data processing\n",
        "import nltk\n",
        "from nltk.corpus import wordnet\n",
        "\n",
        "# UI\n",
        "import gradio as gr\n",
        "\n",
        "# Download NLTK data. nltk.download() reports most failures by returning False\n",
        "# rather than raising, so check its return value as well. A list (not a generator)\n",
        "# is used so all() does not short-circuit and both packages are always attempted.\n",
        "try:\n",
        "    nltk_ok = all([nltk.download(pkg, quiet=True) for pkg in ('wordnet', 'omw-1.4')])\n",
        "except Exception as e:\n",
        "    print(f\"NLTK download error: {e}\")\n",
        "    nltk_ok = False\n",
        "if not nltk_ok:\n",
        "    print(\"NLTK data download may have failed - synonym features may not work\")\n",
        "\n",
        "print(\"✅ All imports successful!\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "a206f9d4",
      "metadata": {
        "id": "a206f9d4"
      },
      "outputs": [],
      "source": [
        "# API Key Setup - Support both Colab and Local environments\n",
        "def setup_api_keys():\n",
        "    \"\"\"Load API keys from Colab secrets or the local environment and build API clients.\n",
        "\n",
        "    Each key is resolved on its own: the Colab secret is tried first, then the\n",
        "    environment variable (after loading a local `.env` file when python-dotenv is\n",
        "    installed). A missing secret therefore only disables that one provider.\n",
        "\n",
        "    Returns:\n",
        "        (api_keys, clients): `api_keys` maps provider name to its key or None;\n",
        "        `clients` maps provider name to an OpenAI-compatible client, only for\n",
        "        providers whose key was found. Also logs in to the HuggingFace Hub when\n",
        "        HF_TOKEN is set.\n",
        "    \"\"\"\n",
        "    secret_names = {\n",
        "        'openai': 'OPENAI_API_KEY',\n",
        "        'anthropic': 'ANTHROPIC_API_KEY',\n",
        "        'google': 'GOOGLE_API_KEY',\n",
        "        'deepseek': 'DEEPSEEK_API_KEY',\n",
        "        'grok': 'GROK_API_KEY',\n",
        "        'hf_token': 'HF_TOKEN',\n",
        "    }\n",
        "\n",
        "    try:\n",
        "        from google.colab import userdata\n",
        "        print(\"✅ Using Colab secrets\")\n",
        "    except ImportError:\n",
        "        userdata = None\n",
        "        try:\n",
        "            from dotenv import load_dotenv\n",
        "            load_dotenv()\n",
        "            print(\"✅ Using local .env file\")\n",
        "        except ImportError:\n",
        "            print(\"⚠️ python-dotenv not installed - reading keys from environment variables only\")\n",
        "\n",
        "    def get_secret(name):\n",
        "        # userdata.get raises when a secret is missing or notebook access was not\n",
        "        # granted; treat that as \"not set\" for this key only, so one missing secret\n",
        "        # does not discard all the others.\n",
        "        if userdata is not None:\n",
        "            try:\n",
        "                value = userdata.get(name)\n",
        "                if value:\n",
        "                    return value\n",
        "            except Exception:\n",
        "                pass\n",
        "        return os.getenv(name)\n",
        "\n",
        "    api_keys = {provider: get_secret(var) for provider, var in secret_names.items()}\n",
        "\n",
        "    # OpenAI-compatible endpoint per provider (None = the OpenAI SDK default)\n",
        "    base_urls = {\n",
        "        'openai': None,\n",
        "        'anthropic': \"https://api.anthropic.com/v1/\",\n",
        "        'google': \"https://generativelanguage.googleapis.com/v1beta/openai/\",\n",
        "        'deepseek': \"https://api.deepseek.com\",\n",
        "        'grok': \"https://api.x.ai/v1\",\n",
        "    }\n",
        "\n",
        "    clients = {}\n",
        "    for provider, base_url in base_urls.items():\n",
        "        if api_keys[provider]:\n",
        "            clients[provider] = OpenAI(api_key=api_keys[provider], base_url=base_url)\n",
        "\n",
        "    missing = [provider for provider in base_urls if not api_keys[provider]]\n",
        "    if missing:\n",
        "        print(f\"ℹ️ No API key for: {', '.join(missing)} - those models are unavailable\")\n",
        "\n",
        "    if api_keys['hf_token']:\n",
        "        login(api_keys['hf_token'], add_to_git_credential=True)\n",
        "\n",
        "    return api_keys, clients\n",
        "\n",
        "# Initialize API keys and clients\n",
        "api_keys, clients = setup_api_keys()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5a791f39",
      "metadata": {
        "id": "5a791f39"
      },
      "outputs": [],
      "source": [
        "# Model Configuration\n",
        "\n",
        "# HuggingFace Models\n",
        "HUGGINGFACE_MODELS = {\n",
        "    \"Llama 3.1 8B\": {\n",
        "        \"model_id\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
        "        \"description\": \"Good for structured data generation\",\n",
        "        \"size\": \"8B\",\n",
        "        \"type\": \"huggingface\",\n",
        "        \"model_class\": \"LlamaForCausalLM\"\n",
        "    },\n",
        "    \"Llama 3.2 3B\": {\n",
        "        \"model_id\": \"meta-llama/Llama-3.2-3B-Instruct\",\n",
        "        \"description\": \"Smaller and faster model for simple schemas\",\n",
        "        \"size\": \"3B\",\n",
        "        \"type\": \"huggingface\",\n",
        "        \"model_class\": \"LlamaForCausalLM\"\n",
        "    },\n",
        "    \"Phi-3.5 Mini\": {\n",
        "        \"model_id\": \"microsoft/Phi-3.5-mini-instruct\",\n",
        "        \"description\": \"Reasoning capabilities\",\n",
        "        \"size\": \"3.8B\",\n",
        "        \"type\": \"huggingface\",\n",
        "        \"model_class\": \"Phi3ForCausalLM\"\n",
        "    },\n",
        "    \"Gemma 2 9B\": {\n",
        "        \"model_id\": \"google/gemma-2-9b-it\",\n",
        "        \"description\": \"Instruction-tuned model\",\n",
        "        \"size\": \"9B\",\n",
        "        \"type\": \"huggingface\",\n",
        "        \"model_class\": \"GemmaForCausalLM\"\n",
        "    },\n",
        "    \"Qwen 2.5 7B\": {\n",
        "        \"model_id\": \"Qwen/Qwen2.5-7B-Instruct\",\n",
        "        \"description\": \"Multilingual that is good for diverse data\",\n",
        "        \"size\": \"7B\",\n",
        "        \"type\": \"huggingface\",\n",
        "        \"model_class\": \"Qwen2ForCausalLM\"\n",
        "    },\n",
        "    \"Mistral 7B\": {\n",
        "        \"model_id\": \"mistralai/Mistral-7B-Instruct-v0.3\",\n",
        "        \"description\": \"Fast inference\",\n",
        "        \"size\": \"7B\",\n",
        "        \"type\": \"huggingface\",\n",
        "        \"model_class\": \"MistralForCausalLM\"\n",
        "    },\n",
        "    \"Zephyr 7B\": {\n",
        "        \"model_id\": \"HuggingFaceH4/zephyr-7b-beta\",\n",
        "        \"description\": \"Fine-tuned for instruction following\",\n",
        "        \"size\": \"7B\",\n",
        "        \"type\": \"huggingface\",\n",
        "        \"model_class\": \"ZephyrForCausalLM\"\n",
        "    }\n",
        "}\n",
        "\n",
        "# Commercial Models\n",
        "COMMERCIAL_MODELS = {\n",
        "    \"GPT-5 Mini\": {\n",
        "        \"model_id\": \"gpt-5-mini\",\n",
        "        \"description\": \"Fast, cost-effective OpenAI model\",\n",
        "        \"provider\": \"openai\",\n",
        "        \"type\": \"commercial\"\n",
        "    },\n",
        "    \"Claude 4.5 Haiku\": {\n",
        "        \"model_id\": \"claude-4.5-haiku-20251001\",\n",
        "        \"description\": \"Balance of speed and quality\",\n",
        "        \"provider\": \"anthropic\",\n",
        "        \"type\": \"commercial\"\n",
        "    },\n",
        "    \"Gemini 2.5 Flash\": {\n",
        "        \"model_id\": \"gemini-2.5-flash-lite\",\n",
        "        \"description\": \"Fast Google model\",\n",
        "        \"provider\": \"google\",\n",
        "        \"type\": \"commercial\"\n",
        "    },\n",
        "    \"DeepSeek Chat\": {\n",
        "        \"model_id\": \"deepseek-chat\",\n",
        "        \"description\": \"Cost-effective with good performance\",\n",
        "        \"provider\": \"deepseek\",\n",
        "        \"type\": \"commercial\"\n",
        "    },\n",
        "    \"Grok 4\": {\n",
        "        \"model_id\": \"grok-4\",\n",
        "        \"description\": \"Grok 4\",\n",
        "        \"provider\": \"grok\",\n",
        "        \"type\": \"commercial\"\n",
        "    }\n",
        "}\n",
        "\n",
        "# Output formats\n",
        "OUTPUT_FORMATS = [\".csv\", \".tsv\", \".json\", \".jsonl\"]\n",
        "\n",
        "# Default schema for pharmacogenomics (PGx) example\n",
        "DEFAULT_SCHEMA = [\n",
        "    (\"patient_id\", \"TEXT\", \"Unique patient identifier\", \"PGX_001\"),\n",
        "    (\"age\", \"INT\", \"Patient age in years\", 45),\n",
        "    (\"gender\", \"TEXT\", \"Patient gender\", \"Female\"),\n",
        "    (\"ethnicity\", \"TEXT\", \"Patient ethnicity\", \"Caucasian\"),\n",
        "    (\"gene_variant\", \"TEXT\", \"Genetic variant\", \"CYP2D6*1/*4\"),\n",
        "    (\"drug_name\", \"TEXT\", \"Medication name\", \"Warfarin\"),\n",
        "    (\"dosage\", \"TEXT\", \"Drug dosage\", \"5mg daily\"),\n",
        "    (\"adverse_reaction\", \"TEXT\", \"Any adverse reactions\", \"None\"),\n",
        "    (\"efficacy_score\", \"INT\", \"Treatment efficacy (1-10)\", 8),\n",
        "    (\"metabolizer_status\", \"TEXT\", \"Drug metabolizer phenotype\", \"Intermediate\")\n",
        "]\n",
        "\n",
        "DEFAULT_SCHEMA_TEXT = \"\\n\".join([f\"{i+1}. {col[0]} ({col[1]}) - {col[2]}, example: {col[3]}\" for i, col in enumerate(DEFAULT_SCHEMA)])\n",
        "\n",
        "print(f\"📊 Available HuggingFace models: {len(HUGGINGFACE_MODELS)}\")\n",
        "print(f\"🌐 Available Commercial models: {len(COMMERCIAL_MODELS)}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# HuggingFace Model Loading\n",
        "def load_huggingface_model(model_id, model_class_name, quantization_config, torch_dtype):\n",
        "    \"\"\"Load a HuggingFace causal LM, preferring the architecture class named in the config.\n",
        "\n",
        "    Args:\n",
        "        model_id: Hub repository id, e.g. \"meta-llama/Llama-3.2-3B-Instruct\".\n",
        "        model_class_name: name of a transformers class such as \"LlamaForCausalLM\".\n",
        "            Names transformers does not export fall back to AutoModelForCausalLM.\n",
        "        quantization_config: quantization settings passed to from_pretrained.\n",
        "        torch_dtype: dtype for the non-quantized weights.\n",
        "\n",
        "    Returns:\n",
        "        The loaded model, placed with device_map=\"auto\".\n",
        "\n",
        "    Raises:\n",
        "        Exception: if neither the specific class nor AutoModelForCausalLM can load it.\n",
        "    \"\"\"\n",
        "    import transformers\n",
        "\n",
        "    load_kwargs = dict(\n",
        "        device_map=\"auto\",\n",
        "        quantization_config=quantization_config,\n",
        "        torch_dtype=torch_dtype,\n",
        "    )\n",
        "    # Resolve the class by name instead of a hand-maintained if/elif chain, so any\n",
        "    # architecture transformers exports (e.g. Gemma2ForCausalLM) works unchanged.\n",
        "    model_class = getattr(transformers, model_class_name, AutoModelForCausalLM)\n",
        "\n",
        "    try:\n",
        "        return model_class.from_pretrained(model_id, **load_kwargs)\n",
        "    except Exception as e:\n",
        "        print(f\"Error loading {model_class_name}: {str(e)}\")\n",
        "        if model_class is AutoModelForCausalLM:\n",
        "            # Retrying with the very same class would fail the same way.\n",
        "            raise Exception(f\"Failed to load model with both specific and auto classes: {str(e)}\") from e\n",
        "\n",
        "    # Release whatever the failed attempt already put on the GPU before retrying.\n",
        "    gc.collect()\n",
        "    torch.cuda.empty_cache()\n",
        "    try:\n",
        "        return AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)\n",
        "    except Exception as e2:\n",
        "        raise Exception(f\"Failed to load model with both specific and auto classes: {str(e2)}\") from e2"
      ],
      "metadata": {
        "id": "NaShTv335Zjr"
      },
      "id": "NaShTv335Zjr",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# 4-bit NF4 quantization shared by every HuggingFace model load.\n",
        "# bfloat16 compute needs an Ampere-or-newer GPU (compute capability >= 8.0); the\n",
        "# Colab T4 this notebook targets is 7.5, so fall back to float16 there.\n",
        "_supports_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8\n",
        "\n",
        "quantization_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16 if _supports_bf16 else torch.float16,\n",
        "    bnb_4bit_quant_type=\"nf4\"\n",
        ")"
      ],
      "metadata": {
        "id": "7IRVMhT65axX"
      },
      "id": "7IRVMhT65axX",
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "5d2f459a",
      "metadata": {
        "id": "5d2f459a"
      },
      "outputs": [],
      "source": [
        "# Schema Management Module\n",
        "class SchemaManager:\n",
        "    \"\"\"Handles schema creation, parsing, and enhancement\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.current_schema = None\n",
        "        self.schema_text = None\n",
        "        self.quantization_config = quantization_config\n",
        "\n",
        "    def generate_schema_with_llm(self, business_case: str, model_name: str, temperature: float = 0.7) -> str:\n",
        "        \"\"\"Generate complete schema from business case using LLM\"\"\"\n",
        "        system_prompt = \"\"\"You are an expert data scientist. Given a business case, generate a comprehensive dataset schema.\n",
        "        Return the schema in this exact format:\n",
        "        field_name (TYPE) - Description, example: example_value\n",
        "\n",
        "        Include 8-12 relevant fields that would be useful for the business case.\n",
        "        Use realistic field names and appropriate data types (TEXT, INT, FLOAT, BOOLEAN, ARRAY).\n",
        "        Provide clear descriptions and realistic examples.\"\"\"\n",
        "\n",
        "        user_prompt = f\"\"\"\\n\\nBusiness case: {business_case}\n",
        "\n",
        "        Generate a dataset schema for this business case. Include fields that would be relevant for analysis and decision-making.\"\"\"\n",
        "\n",
        "        try:\n",
        "            response = self._query_llm(model_name, system_prompt, user_prompt, temperature)\n",
        "            self.schema_text = response\n",
        "            return response\n",
        "        except Exception as e:\n",
        "            return f\"Error generating schema: {str(e)}\"\n",
        "\n",
        "    def enhance_schema_with_llm(self, partial_schema: str, business_case: str, model_name: str, temperature: float = 0.7) -> str:\n",
        "        \"\"\"Enhance user-provided partial schema using LLM\"\"\"\n",
        "        system_prompt = \"\"\"You are an expert data scientist. Given a partial schema and business case, enhance it by:\n",
        "        1. Adding missing relevant fields\n",
        "        2. Improving field descriptions\n",
        "        3. Adding realistic examples\n",
        "        4. Ensuring proper data types\n",
        "\n",
        "        Return the enhanced schema in the same format as the original.\"\"\"\n",
        "\n",
        "        user_prompt = f\"\"\"\\n\\nBusiness case: {business_case}\n",
        "\n",
        "        Current partial schema:\n",
        "        {partial_schema}\n",
        "\n",
        "        Please enhance this schema by adding missing fields and improving the existing ones.\"\"\"\n",
        "\n",
        "        try:\n",
        "            response = self._query_llm(model_name, system_prompt, user_prompt, temperature)\n",
        "            self.schema_text = response\n",
        "            return response\n",
        "        except Exception as e:\n",
        "            return f\"Error enhancing schema: {str(e)}\"\n",
        "\n",
        "    def parse_manual_schema(self, schema_text: str) -> Dict[str, Any]:\n",
        "        \"\"\"Parse manually entered schema text\"\"\"\n",
        "        try:\n",
        "            lines = [line.strip() for line in schema_text.split('\\n') if line.strip()]\n",
        "            parsed_schema = []\n",
        "\n",
        "            for line in lines:\n",
        "                if re.match(r'^\\d+\\.', line):  # Skip line numbers\n",
        "                    line = re.sub(r'^\\d+\\.\\s*', '', line)\n",
        "\n",
        "                # Parse format: field_name (TYPE) - Description, example: example_value\n",
        "                match = re.match(r'^([^(]+)\\s*\\(([^)]+)\\)\\s*-\\s*([^,]+),\\s*example:\\s*(.+)$', line)\n",
        "                if match:\n",
        "                    field_name, field_type, description, example = match.groups()\n",
        "                    parsed_schema.append({\n",
        "                        'name': field_name.strip(),\n",
        "                        'type': field_type.strip(),\n",
        "                        'description': description.strip(),\n",
        "                        'example': example.strip()\n",
        "                    })\n",
        "\n",
        "            self.current_schema = parsed_schema\n",
        "            return parsed_schema\n",
        "        except Exception as e:\n",
        "            return {\"error\": f\"Error parsing schema: {str(e)}\"}\n",
        "\n",
        "    def format_schema_for_prompt(self, schema: List[Dict]) -> str:\n",
        "        \"\"\"Convert parsed schema to prompt-ready format\"\"\"\n",
        "        if not schema:\n",
        "            return self.schema_text or \"\"\n",
        "\n",
        "        formatted_lines = []\n",
        "        for i, field in enumerate(schema, 1):\n",
        "            line = f\"{i}. {field['name']} ({field['type']}) - {field['description']}, example: {field['example']}\"\n",
        "            formatted_lines.append(line)\n",
        "\n",
        "        return \"\\n\".join(formatted_lines)\n",
        "\n",
        "    def _query_llm(self, model_name: str, system_prompt: str, user_prompt: str, temperature: float) -> str:\n",
        "        \"\"\"Universal LLM query interface\"\"\"\n",
        "        # Check if it's a HuggingFace model\n",
        "        if model_name in HUGGINGFACE_MODELS:\n",
        "            return self._query_huggingface(model_name, system_prompt, user_prompt, temperature)\n",
        "        elif model_name in COMMERCIAL_MODELS:\n",
        "            return self._query_commercial(model_name, system_prompt, user_prompt, temperature)\n",
        "        else:\n",
        "            raise ValueError(f\"Unknown model: {model_name}\")\n",
        "\n",
        "    def _query_huggingface(self, model_name: str, system_prompt: str, user_prompt: str, temperature: float) -> str:\n",
        "        \"\"\"Query HuggingFace models\"\"\"\n",
        "        model_info = HUGGINGFACE_MODELS[model_name]\n",
        "        model_id = model_info[\"model_id\"]\n",
        "\n",
        "        try:\n",
        "            # Check if model is already loaded\n",
        "            if model_name not in dataset_generator.loaded_models:\n",
        "                print(f\"🔄 Loading {model_name} for schema generation...\")\n",
        "\n",
        "                # Load tokenizer\n",
        "                tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
        "                tokenizer.pad_token = tokenizer.eos_token\n",
        "                print(f\"Tokenizer loaded for {model_name}\")\n",
        "\n",
        "                # Load model with quantization using correct model class\n",
        "                model_class_name = model_info.get(\"model_class\", \"AutoModelForCausalLM\")\n",
        "                model = load_huggingface_model(\n",
        "                    model_id,\n",
        "                    model_class_name,\n",
        "                    dataset_generator.quantization_config,\n",
        "                    torch.bfloat16\n",
        "                )\n",
        "\n",
        "                dataset_generator.loaded_models[model_name] = {\n",
        "                    'model': model,\n",
        "                    'tokenizer': tokenizer\n",
        "                }\n",
        "                print(f\"✅ {model_name} loaded successfully for schema generation!\")\n",
        "\n",
        "            # Get model and tokenizer\n",
        "            model = dataset_generator.loaded_models[model_name]['model']\n",
        "            tokenizer = dataset_generator.loaded_models[model_name]['tokenizer']\n",
        "\n",
        "            # Prepare messages\n",
        "            messages = [\n",
        "                {\"role\": \"system\", \"content\": system_prompt},\n",
        "                {\"role\": \"user\", \"content\": user_prompt}\n",
        "            ]\n",
        "\n",
        "            # Tokenize\n",
        "            inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\").to(\"cuda\")\n",
        "\n",
        "            # Generate\n",
        "            with torch.no_grad():\n",
        "                outputs = model.generate(\n",
        "                    inputs,\n",
        "                    max_new_tokens=2000,\n",
        "                    temperature=temperature,\n",
        "                    do_sample=True,\n",
        "                    pad_token_id=tokenizer.eos_token_id\n",
        "                )\n",
        "\n",
        "            # Decode response\n",
        "            response = tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
        "\n",
        "            # Extract only the assistant's response\n",
        "            if \"<|assistant|>\" in response:\n",
        "                response = response.split(\"<|assistant|>\")[-1].strip()\n",
        "            elif \"assistant\" in response:\n",
        "                response = response.split(\"assistant\")[-1].strip()\n",
        "\n",
        "            return response\n",
        "\n",
        "        except Exception as e:\n",
        "            # Clean up on error\n",
        "            if model_name in dataset_generator.loaded_models:\n",
        "                del dataset_generator.loaded_models[model_name]\n",
        "                gc.collect()\n",
        "                torch.cuda.empty_cache()\n",
        "            raise Exception(f\"HuggingFace schema generation error: {str(e)}\")\n",
        "\n",
        "    def _query_commercial(self, model_name: str, system_prompt: str, user_prompt: str, temperature: float) -> str:\n",
        "        \"\"\"Query commercial API models\"\"\"\n",
        "        model_info = COMMERCIAL_MODELS[model_name]\n",
        "        provider = model_info[\"provider\"]\n",
        "        model_id = model_info[\"model_id\"]\n",
        "\n",
        "\n",
        "        try:\n",
        "            response = clients[provider].chat.completions.create(\n",
        "                model=model_id,\n",
        "                messages=[\n",
        "                    {\"role\": \"system\", \"content\": system_prompt},\n",
        "                    {\"role\": \"user\", \"content\": user_prompt}\n",
        "                ],\n",
        "                temperature = temperature if model_id != \"gpt-5-mini\" else 1.0\n",
        "            )\n",
        "            return response.choices[0].message.content\n",
        "\n",
        "        except Exception as e:\n",
        "            return f\"Error querying {model_name}: {str(e)}\"\n",
        "\n",
        "# Initialize schema manager. Its HuggingFace path reads the module-level\n",
        "# `dataset_generator`, so that cell must run before HuggingFace models are used.\n",
        "schema_manager = SchemaManager()\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "dd37ee66",
      "metadata": {
        "id": "dd37ee66"
      },
      "outputs": [],
      "source": [
        "# Dataset Generation Module\n",
        "class DatasetGenerator:\n",
        "    \"\"\"Handles synthetic dataset generation using multiple LLM models\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        # model_name -> {'model': ..., 'tokenizer': ...}; each HuggingFace model is\n",
        "        # loaded once and reused by later calls (SchemaManager uses this cache too).\n",
        "        self.loaded_models = dict()\n",
        "        # 4-bit settings taken from the module-level `quantization_config` cell\n",
        "        self.quantization_config = quantization_config\n",
        "\n",
        "    def generate_dataset(self, schema_text: str, business_case: str, model_name: str,\n",
        "                        temperature: float, num_records: int, examples: str = \"\") -> Tuple[str, List[Dict]]:\n",
        "        \"\"\"Generate synthetic dataset using specified model\"\"\"\n",
        "        try:\n",
        "            # Build generation prompt\n",
        "            prompt = self._build_generation_prompt(schema_text, business_case, num_records, examples)\n",
        "\n",
        "            # Query the model\n",
        "            response = self._query_llm(model_name, prompt, temperature)\n",
        "\n",
        "            # Parse JSONL response\n",
        "            records = self._parse_jsonl_response(response)\n",
        "\n",
        "            if not records:\n",
        "                return \"❌ Error: No valid records generated\", []\n",
        "\n",
        "            if len(records) < num_records:\n",
        "                return f\"⚠️ Warning: Generated {len(records)} records (requested {num_records})\", records\n",
        "\n",
        "            return f\"✅ Generated {len(records)} records successfully!\", records\n",
        "\n",
        "        except Exception as e:\n",
        "            return f\"❌ Error: {str(e)}\", []\n",
        "\n",
        "    def _build_generation_prompt(self, schema_text: str, business_case: str, num_records: int, examples: str) -> str:\n",
        "        \"\"\"Build the generation prompt\"\"\"\n",
        "        prompt = f\"\"\"You are a data generation expert. Generate {num_records} realistic records for the following business case:\n",
        "\n",
        "Business Case: {business_case}\n",
        "\n",
        "Schema:\n",
        "{schema_text}\n",
        "\n",
        "Requirements:\n",
        "- Generate exactly {num_records} records\n",
        "- Each record must be a valid JSON object\n",
        "- Do NOT repeat values across records\n",
        "- Make data realistic and diverse\n",
        "- Output only valid JSONL (one JSON object per line)\n",
        "- No additional text or explanations\n",
        "\n",
        "\"\"\"\n",
        "\n",
        "        if examples.strip():\n",
        "            prompt += f\"\"\"\n",
        "Examples to follow (but do NOT repeat these exact examples):\n",
        "{examples}\n",
        "\n",
        "\"\"\"\n",
        "\n",
        "        prompt += \"Generate the dataset now:\"\n",
        "        return prompt\n",
        "\n",
        "    def _query_llm(self, model_name: str, prompt: str, temperature: float) -> str:\n",
        "        \"\"\"Send `prompt` to whichever backend (HuggingFace or commercial) owns `model_name`.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if `model_name` is in neither model registry.\n",
        "        \"\"\"\n",
        "        if model_name in HUGGINGFACE_MODELS:\n",
        "            backend = self._query_huggingface\n",
        "        elif model_name in COMMERCIAL_MODELS:\n",
        "            backend = self._query_commercial\n",
        "        else:\n",
        "            raise ValueError(f\"Unknown model: {model_name}\")\n",
        "        return backend(model_name, prompt, temperature)\n",
        "\n",
        "    def _query_huggingface(self, model_name: str, prompt: str, temperature: float) -> str:\n",
        "        \"\"\"Run ``prompt`` through a local HuggingFace model and return only the generated text.\n",
        "\n",
        "        The tokenizer/model pair is loaded on first use via ``load_huggingface_model``\n",
        "        (with ``self.quantization_config``) and cached in ``self.loaded_models``.\n",
        "\n",
        "        Args:\n",
        "            model_name: Key into ``HUGGINGFACE_MODELS``.\n",
        "            prompt: User message, sent after a fixed system message.\n",
        "            temperature: Sampling temperature; values <= 0 switch to greedy decoding\n",
        "                because ``generate`` rejects non-positive temperatures when sampling.\n",
        "\n",
        "        Returns:\n",
        "            The decoded completion with the prompt tokens removed, stripped of\n",
        "            surrounding whitespace.\n",
        "\n",
        "        Raises:\n",
        "            Exception: wrapping any load or generation error. The cached model (if\n",
        "                any) is dropped and the CUDA cache emptied before re-raising.\n",
        "        \"\"\"\n",
        "        model_info = HUGGINGFACE_MODELS[model_name]\n",
        "        model_id = model_info[\"model_id\"]\n",
        "\n",
        "        try:\n",
        "            # Check if model is already loaded\n",
        "            if model_name not in self.loaded_models:\n",
        "                print(f\"🔄 Loading {model_name}...\")\n",
        "\n",
        "                # Load tokenizer\n",
        "                tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
        "                tokenizer.pad_token = tokenizer.eos_token\n",
        "\n",
        "                # Load model with quantization using correct model class\n",
        "                model_class_name = model_info.get(\"model_class\", \"AutoModelForCausalLM\")\n",
        "                model = load_huggingface_model(\n",
        "                    model_id,\n",
        "                    model_class_name,\n",
        "                    self.quantization_config,\n",
        "                    torch.bfloat16\n",
        "                )\n",
        "\n",
        "                self.loaded_models[model_name] = {\n",
        "                    'model': model,\n",
        "                    'tokenizer': tokenizer\n",
        "                }\n",
        "                print(f\"✅ {model_name} loaded successfully!\")\n",
        "\n",
        "            # Get model and tokenizer\n",
        "            model = self.loaded_models[model_name]['model']\n",
        "            tokenizer = self.loaded_models[model_name]['tokenizer']\n",
        "\n",
        "            # Prepare messages\n",
        "            messages = [\n",
        "                {\"role\": \"system\", \"content\": \"You are a helpful assistant that generates realistic datasets.\"},\n",
        "                {\"role\": \"user\", \"content\": prompt}\n",
        "            ]\n",
        "\n",
        "            # add_generation_prompt appends the assistant-turn header so the model\n",
        "            # starts its reply instead of continuing the user turn; model.device\n",
        "            # avoids hard-coding \"cuda\"\n",
        "            inputs = tokenizer.apply_chat_template(\n",
        "                messages,\n",
        "                add_generation_prompt=True,\n",
        "                return_tensors=\"pt\"\n",
        "            ).to(model.device)\n",
        "\n",
        "            # Sampling requires temperature > 0; fall back to greedy decoding otherwise\n",
        "            generation_kwargs = {\n",
        "                \"max_new_tokens\": 4000,\n",
        "                \"pad_token_id\": tokenizer.eos_token_id,\n",
        "            }\n",
        "            if temperature > 0:\n",
        "                generation_kwargs[\"do_sample\"] = True\n",
        "                generation_kwargs[\"temperature\"] = temperature\n",
        "            else:\n",
        "                generation_kwargs[\"do_sample\"] = False\n",
        "\n",
        "            # Generate\n",
        "            with torch.no_grad():\n",
        "                outputs = model.generate(inputs, **generation_kwargs)\n",
        "\n",
        "            # Decoder-only models echo the prompt tokens; slice them off and decode\n",
        "            # only the new tokens. Splitting the full text on the word \"assistant\"\n",
        "            # truncated any generated data that contained that word.\n",
        "            generated = outputs[0]\n",
        "            if not getattr(model.config, \"is_encoder_decoder\", False):\n",
        "                generated = generated[inputs.shape[-1]:]\n",
        "            response = tokenizer.decode(generated, skip_special_tokens=True)\n",
        "\n",
        "            return response.strip()\n",
        "\n",
        "        except Exception as e:\n",
        "            # Clean up on error\n",
        "            if model_name in self.loaded_models:\n",
        "                del self.loaded_models[model_name]\n",
        "                gc.collect()\n",
        "                torch.cuda.empty_cache()\n",
        "            raise Exception(f\"HuggingFace model error: {str(e)}\") from e\n",
        "\n",
        "    def _query_commercial(self, model_name: str, prompt: str, temperature: float) -> str:\n",
        "        \"\"\"Send ``prompt`` to a commercial chat-completion API and return the reply text.\n",
        "\n",
        "        Provider and API model id come from ``COMMERCIAL_MODELS[model_name]``; the\n",
        "        request goes through ``clients[provider]``. For ``gpt-5-mini`` the\n",
        "        temperature is always sent as 1.0, whatever the caller passed.\n",
        "\n",
        "        Raises:\n",
        "            Exception: wrapping any error raised while calling the client or\n",
        "                reading its response.\n",
        "        \"\"\"\n",
        "        spec = COMMERCIAL_MODELS[model_name]\n",
        "        provider = spec[\"provider\"]\n",
        "        api_model_id = spec[\"model_id\"]\n",
        "\n",
        "        # gpt-5-mini is pinned to 1.0 (presumably it rejects other temperature values)\n",
        "        effective_temperature = 1.0 if api_model_id == \"gpt-5-mini\" else temperature\n",
        "        conversation = [\n",
        "            {\"role\": \"system\", \"content\": \"You are a helpful assistant that generates realistic datasets.\"},\n",
        "            {\"role\": \"user\", \"content\": prompt},\n",
        "        ]\n",
        "\n",
        "        try:\n",
        "            completion = clients[provider].chat.completions.create(\n",
        "                model=api_model_id,\n",
        "                messages=conversation,\n",
        "                temperature=effective_temperature,\n",
        "            )\n",
        "            return completion.choices[0].message.content\n",
        "        except Exception as e:\n",
        "            raise Exception(f\"Commercial API error: {str(e)}\")\n",
        "\n",
        "    def _parse_jsonl_response(self, response: str) -> List[Dict]:\n",
        "        \"\"\"Parse a JSONL-style model response into a list of record dicts.\n",
        "\n",
        "        Lines that do not start with ``{`` (prose, code-fence markers, array\n",
        "        brackets) are skipped, as are lines that are not valid JSON objects. A\n",
        "        trailing comma is stripped before parsing because models often emit\n",
        "        array-style ``{...},`` lines even when asked for JSONL.\n",
        "\n",
        "        Returns:\n",
        "            Parsed records in response order; an empty list for an empty or None response.\n",
        "        \"\"\"\n",
        "        if not response:\n",
        "            return []\n",
        "\n",
        "        records = []\n",
        "        for raw_line in response.strip().split('\\n'):\n",
        "            # Tolerate \"{...},\" lines produced when the model mimics a JSON array\n",
        "            line = raw_line.strip().rstrip(',').rstrip()\n",
        "\n",
        "            # Skip non-JSON lines\n",
        "            if not line.startswith('{'):\n",
        "                continue\n",
        "\n",
        "            try:\n",
        "                record = json.loads(line)\n",
        "            except json.JSONDecodeError:\n",
        "                continue\n",
        "            if isinstance(record, dict):\n",
        "                records.append(record)\n",
        "\n",
        "        return records\n",
        "\n",
        "    def unload_model(self, model_name: str):\n",
        "        \"\"\"Drop a cached HuggingFace model and release the memory it held.\n",
        "\n",
        "        Does nothing when ``model_name`` is not in ``self.loaded_models``.\n",
        "        \"\"\"\n",
        "        if model_name not in self.loaded_models:\n",
        "            return\n",
        "        self.loaded_models.pop(model_name)\n",
        "        gc.collect()\n",
        "        torch.cuda.empty_cache()\n",
        "        print(f\"✅ {model_name} unloaded from memory\")\n",
        "\n",
        "    def get_memory_usage(self) -> str:\n",
        "        \"\"\"Describe current CUDA memory use in GB, or report that no GPU is available.\"\"\"\n",
        "        if not torch.cuda.is_available():\n",
        "            return \"GPU not available\"\n",
        "        bytes_per_gb = 1024**3\n",
        "        allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb\n",
        "        reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb\n",
        "        return f\"GPU Memory: {allocated_gb:.2f}GB allocated, {reserved_gb:.2f}GB reserved\"\n",
        "\n",
        "# Initialize dataset generator\n",
        "# Shared module-level instance: QualityScorer._query_scoring_model and the Gradio\n",
        "# callbacks in later cells call into this `dataset_generator` object.\n",
        "dataset_generator = DatasetGenerator()\n",
        "print(\"✅ Dataset Generation Module loaded!\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "350a1468",
      "metadata": {
        "id": "350a1468"
      },
      "outputs": [],
      "source": [
        "# Quality Scoring Module\n",
        "class QualityScorer:\n",
        "    \"\"\"Evaluates dataset quality using separate LLM models.\n",
        "\n",
        "    A rubric is built once with ``extract_quality_rules``; each record is then\n",
        "    sent to a scoring model together with that rubric, and the reply is parsed\n",
        "    into a 0-100 score.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.quality_rules = None  # rubric text, set by extract_quality_rules\n",
        "        self.scoring_model = None\n",
        "\n",
        "    def extract_quality_rules(self, original_prompt: str, schema_text: str) -> str:\n",
        "        \"\"\"Build the scoring rubric and store it in ``self.quality_rules``.\n",
        "\n",
        "        Args:\n",
        "            original_prompt: Business-case context shown to the scoring model.\n",
        "            schema_text: Schema description the records are checked against.\n",
        "\n",
        "        Returns:\n",
        "            The rubric text.\n",
        "        \"\"\"\n",
        "        rules = f\"\"\"Quality Assessment Rules for Dataset:\n",
        "\n",
        "1. **Schema Compliance (25 points)**\n",
        "   - All required fields from schema are present\n",
        "   - Data types match schema specifications\n",
        "   - No missing values in critical fields\n",
        "\n",
        "2. **Uniqueness (20 points)**\n",
        "   - No duplicate records\n",
        "   - Diverse values across records\n",
        "   - Avoid repetitive patterns\n",
        "\n",
        "3. **Relevance to Business Case (25 points)**\n",
        "   - Data aligns with business context\n",
        "   - Realistic scenarios and values\n",
        "   - Appropriate level of detail\n",
        "\n",
        "4. **Realism and Coherence (20 points)**\n",
        "   - Values are realistic and plausible\n",
        "   - Internal consistency within records\n",
        "   - Logical relationships between fields\n",
        "\n",
        "5. **Diversity (10 points)**\n",
        "   - Varied values across the dataset\n",
        "   - Different scenarios represented\n",
        "   - Balanced distribution where appropriate\n",
        "\n",
        "Schema Requirements:\n",
        "{schema_text}\n",
        "\n",
        "Original Business Case Context:\n",
        "{original_prompt}\n",
        "\n",
        "Score each record from 0-100 based on these criteria.\"\"\"\n",
        "\n",
        "        self.quality_rules = rules\n",
        "        return rules\n",
        "\n",
        "    def score_single_record(self, record: Dict, model_name: str, temperature: float = 0.3) -> int:\n",
        "        \"\"\"Score a single dataset record (0-100).\n",
        "\n",
        "        Returns 0 when no rubric has been set yet, or when querying or parsing\n",
        "        fails (the error is printed).\n",
        "        \"\"\"\n",
        "        if not self.quality_rules:\n",
        "            return 0\n",
        "\n",
        "        try:\n",
        "            # Prepare scoring prompt\n",
        "            prompt = f\"\"\"{self.quality_rules}\n",
        "\n",
        "Record to evaluate:\n",
        "{json.dumps(record, indent=2)}\n",
        "\n",
        "Provide a score from 0-100 and brief explanation. Format: \"Score: XX - Explanation\" \"\"\"\n",
        "\n",
        "            # Query the scoring model\n",
        "            response = self._query_scoring_model(model_name, prompt, temperature)\n",
        "\n",
        "            # Extract score from response\n",
        "            score = self._extract_score_from_response(response)\n",
        "            return score\n",
        "\n",
        "        except Exception as e:\n",
        "            print(f\"Error scoring record: {e}\")\n",
        "            return 0\n",
        "\n",
        "    def score_dataset(self, dataset: List[Dict], model_name: str, temperature: float = 0.3) -> Tuple[List[int], Dict[str, Any]]:\n",
        "        \"\"\"Score all records in the dataset.\n",
        "\n",
        "        Returns:\n",
        "            ``(scores, stats)``: one score per record, plus a dict with the\n",
        "            average/min/max score and the count and percentage per band\n",
        "            (excellent >= 90, good 70-89, fair 50-69, poor < 50). Returns\n",
        "            ``([], {})`` for an empty dataset.\n",
        "        \"\"\"\n",
        "        if not dataset:\n",
        "            return [], {}\n",
        "\n",
        "        scores = []\n",
        "        total_score = 0\n",
        "\n",
        "        print(f\"🔄 Scoring {len(dataset)} records with {model_name}...\")\n",
        "\n",
        "        for i, record in enumerate(dataset):\n",
        "            score = self.score_single_record(record, model_name, temperature)\n",
        "            scores.append(score)\n",
        "            total_score += score\n",
        "\n",
        "            if (i + 1) % 10 == 0:\n",
        "                print(f\"   Scored {i + 1}/{len(dataset)} records...\")\n",
        "\n",
        "        # Calculate statistics\n",
        "        avg_score = total_score / len(scores) if scores else 0\n",
        "        min_score = min(scores) if scores else 0\n",
        "        max_score = max(scores) if scores else 0\n",
        "\n",
        "        # Count quality levels\n",
        "        excellent = sum(1 for s in scores if s >= 90)\n",
        "        good = sum(1 for s in scores if 70 <= s < 90)\n",
        "        fair = sum(1 for s in scores if 50 <= s < 70)\n",
        "        poor = sum(1 for s in scores if s < 50)\n",
        "\n",
        "        stats = {\n",
        "            'total_records': len(dataset),\n",
        "            'average_score': round(avg_score, 2),\n",
        "            'min_score': min_score,\n",
        "            'max_score': max_score,\n",
        "            'excellent_count': excellent,\n",
        "            'good_count': good,\n",
        "            'fair_count': fair,\n",
        "            'poor_count': poor,\n",
        "            'excellent_pct': round(excellent / len(dataset) * 100, 1),\n",
        "            'good_pct': round(good / len(dataset) * 100, 1),\n",
        "            'fair_pct': round(fair / len(dataset) * 100, 1),\n",
        "            'poor_pct': round(poor / len(dataset) * 100, 1)\n",
        "        }\n",
        "\n",
        "        return scores, stats\n",
        "\n",
        "    def generate_quality_report(self, scores: List[int], dataset: List[Dict],\n",
        "                             flagged_threshold: int = 70) -> Dict[str, Any]:\n",
        "        \"\"\"Generate comprehensive quality report.\n",
        "\n",
        "        Records scoring below ``flagged_threshold`` are flagged; at most the first\n",
        "        10 flagged records are included. Returns ``{\"error\": ...}`` when\n",
        "        ``scores`` or ``dataset`` is empty.\n",
        "        \"\"\"\n",
        "        if not scores or not dataset:\n",
        "            return {\"error\": \"No data to analyze\"}\n",
        "\n",
        "        # Find flagged records (low quality)\n",
        "        flagged_records = []\n",
        "        for i, (record, score) in enumerate(zip(dataset, scores)):\n",
        "            if score < flagged_threshold:\n",
        "                flagged_records.append({\n",
        "                    'index': i,\n",
        "                    'score': score,\n",
        "                    'record': record\n",
        "                })\n",
        "\n",
        "        # Quality distribution\n",
        "        score_ranges = {\n",
        "            '90-100': sum(1 for s in scores if s >= 90),\n",
        "            '80-89': sum(1 for s in scores if 80 <= s < 90),\n",
        "            '70-79': sum(1 for s in scores if 70 <= s < 80),\n",
        "            '60-69': sum(1 for s in scores if 60 <= s < 70),\n",
        "            '50-59': sum(1 for s in scores if 50 <= s < 60),\n",
        "            '0-49': sum(1 for s in scores if s < 50)\n",
        "        }\n",
        "\n",
        "        report = {\n",
        "            'total_records': len(dataset),\n",
        "            'average_score': round(sum(scores) / len(scores), 2),\n",
        "            'flagged_count': len(flagged_records),\n",
        "            'flagged_percentage': round(len(flagged_records) / len(dataset) * 100, 1),\n",
        "            'score_distribution': score_ranges,\n",
        "            'flagged_records': flagged_records[:10],  # Limit to first 10 for display\n",
        "            'recommendations': self._generate_recommendations(scores, flagged_records)\n",
        "        }\n",
        "\n",
        "        return report\n",
        "\n",
        "    def _query_scoring_model(self, model_name: str, prompt: str, temperature: float) -> str:\n",
        "        \"\"\"Query the scoring model through the shared ``dataset_generator`` backends.\n",
        "\n",
        "        Raises:\n",
        "            ValueError: if ``model_name`` is in neither model registry.\n",
        "        \"\"\"\n",
        "        # Use the same interface as dataset generation\n",
        "        if model_name in HUGGINGFACE_MODELS:\n",
        "            return dataset_generator._query_huggingface(model_name, prompt, temperature)\n",
        "        elif model_name in COMMERCIAL_MODELS:\n",
        "            return dataset_generator._query_commercial(model_name, prompt, temperature)\n",
        "        else:\n",
        "            raise ValueError(f\"Unknown scoring model: {model_name}\")\n",
        "\n",
        "    def _extract_score_from_response(self, response: str) -> int:\n",
        "        \"\"\"Extract a numerical 0-100 score from the scoring model's reply.\n",
        "\n",
        "        Explicit forms (\"Score: 85\", \"**Score:** 85\", \"85/100\", \"85 points\",\n",
        "        \"85 out of 100\") are tried in that order; otherwise the first integer\n",
        "        in the text is used. The result is clamped to 0-100.\n",
        "\n",
        "        Returns:\n",
        "            The parsed score, or 50 when the reply is empty or contains no number.\n",
        "        \"\"\"\n",
        "        if not response:\n",
        "            return 50  # Same default as a reply without any number\n",
        "\n",
        "        # The label pattern tolerates markdown emphasis and '=' (e.g. \"**Score:** 85\").\n",
        "        # A plain 'Score:'-then-whitespace pattern misses that form, and a later\n",
        "        # pattern would then pick up a rubric number such as \"25 points\" echoed by the model.\n",
        "        score_patterns = [\n",
        "            r'Score[*_\\s]*[:=][*_\\s]*(\\d+)',\n",
        "            r'(\\d+)/100',\n",
        "            r'(\\d+)\\s*points',\n",
        "            r'(\\d+)\\s*out of 100'\n",
        "        ]\n",
        "\n",
        "        for pattern in score_patterns:\n",
        "            match = re.search(pattern, response, re.IGNORECASE)\n",
        "            if match:\n",
        "                score = int(match.group(1))\n",
        "                return max(0, min(100, score))  # Clamp between 0-100\n",
        "\n",
        "        # If no pattern found, try to find any number in the response\n",
        "        numbers = re.findall(r'\\d+', response)\n",
        "        if numbers:\n",
        "            score = int(numbers[0])\n",
        "            return max(0, min(100, score))\n",
        "\n",
        "        return 50  # Default score if no number found\n",
        "\n",
        "    def _generate_recommendations(self, scores: List[int], flagged_records: List[Dict]) -> List[str]:\n",
        "        \"\"\"Generate recommendations based on quality analysis.\n",
        "\n",
        "        Expects a non-empty ``scores`` list; generate_quality_report guards against\n",
        "        empty input before calling this.\n",
        "        \"\"\"\n",
        "        recommendations = []\n",
        "\n",
        "        avg_score = sum(scores) / len(scores)\n",
        "\n",
        "        if avg_score < 70:\n",
        "            recommendations.append(\"Consider regenerating the dataset with a different model or parameters\")\n",
        "\n",
        "        if len(flagged_records) > len(scores) * 0.3:\n",
        "            recommendations.append(\"High number of low-quality records - review generation prompt\")\n",
        "\n",
        "        if max(scores) - min(scores) > 50:\n",
        "            recommendations.append(\"High variance in quality - consider more consistent generation approach\")\n",
        "\n",
        "        if avg_score >= 85:\n",
        "            recommendations.append(\"Excellent dataset quality - ready for use\")\n",
        "        elif avg_score >= 70:\n",
        "            recommendations.append(\"Good dataset quality - minor improvements possible\")\n",
        "        else:\n",
        "            recommendations.append(\"Dataset needs improvement - consider regenerating\")\n",
        "\n",
        "        return recommendations\n",
        "\n",
        "# Initialize quality scorer\n",
        "quality_scorer = QualityScorer()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "756883cd",
      "metadata": {
        "id": "756883cd"
      },
      "outputs": [],
      "source": [
        "# Synonym Permutation Module\n",
        "class SynonymPermutator:\n",
        "    \"\"\"Handles synonym replacement to increase dataset diversity.\n",
        "\n",
        "    Synonyms are looked up in NLTK WordNet (the ``wordnet`` name must exist in the\n",
        "    notebook namespace) and cached per lower-cased word.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.synonym_cache = {}  # Cache for synonyms to avoid repeated lookups\n",
        "\n",
        "    def get_synonyms(self, word: str) -> List[str]:\n",
        "        \"\"\"Get up to 5 synonyms for a word using NLTK WordNet.\n",
        "\n",
        "        Candidates equal to the word or shorter than 3 characters are dropped, as\n",
        "        are ``-ing`` / ``-ed`` forms unless the word itself ends that way.\n",
        "\n",
        "        Returns:\n",
        "            The alphabetically first 5 surviving synonyms (lower-case); empty when\n",
        "            WordNet has no entry or the lookup fails. Results are cached.\n",
        "        \"\"\"\n",
        "        key = word.lower()\n",
        "        if key in self.synonym_cache:\n",
        "            return self.synonym_cache[key]\n",
        "\n",
        "        synonyms = set()\n",
        "        try:\n",
        "            for syn in wordnet.synsets(key):\n",
        "                for lemma in syn.lemmas():\n",
        "                    synonym = lemma.name().replace('_', ' ').lower()\n",
        "                    if synonym != key and len(synonym) > 2:\n",
        "                        synonyms.add(synonym)\n",
        "        except Exception:\n",
        "            # Best effort: a lookup failure (e.g. missing WordNet corpus) just means\n",
        "            # no synonyms. Catching Exception instead of a bare except lets\n",
        "            # KeyboardInterrupt/SystemExit propagate.\n",
        "            pass\n",
        "\n",
        "        # Keep the inflection roughly intact. The suffix checks are explicitly\n",
        "        # parenthesised: without that, `and` binds tighter than `or` and the\n",
        "        # -ing/-ed filtering was effectively bypassed.\n",
        "        filtered_synonyms = [\n",
        "            syn for syn in sorted(synonyms)  # sorted: set order varies between runs\n",
        "            if len(syn) >= 3\n",
        "            and syn != key\n",
        "            and (not syn.endswith('ing') or key.endswith('ing'))\n",
        "            and (not syn.endswith('ed') or key.endswith('ed'))\n",
        "        ]\n",
        "\n",
        "        # Limit to 5 synonyms max\n",
        "        filtered_synonyms = filtered_synonyms[:5]\n",
        "        self.synonym_cache[key] = filtered_synonyms\n",
        "        return filtered_synonyms\n",
        "\n",
        "    def identify_text_fields(self, dataset: List[Dict]) -> List[str]:\n",
        "        \"\"\"Auto-detect text fields suitable for synonym permutation.\n",
        "\n",
        "        Only the first record is inspected. A field qualifies when its value is a\n",
        "        string longer than 3 characters that is neither all digits nor a code made\n",
        "        only of uppercase letters, digits, '_' and '-'.\n",
        "        \"\"\"\n",
        "        if not dataset:\n",
        "            return []\n",
        "\n",
        "        text_fields = []\n",
        "        for key, value in dataset[0].items():\n",
        "            if isinstance(value, str) and len(value) > 3:\n",
        "                # Check if field contains meaningful text (not just IDs or codes)\n",
        "                if not re.match(r'^[A-Z0-9_\\-]+$', value) and not value.isdigit():\n",
        "                    text_fields.append(key)\n",
        "\n",
        "        return text_fields\n",
        "\n",
        "    def permute_with_synonyms(self, dataset: List[Dict], fields_to_permute: List[str],\n",
        "                            permutation_rate: float = 0.3) -> Tuple[List[Dict], Dict[str, int]]:\n",
        "        \"\"\"Replace words with synonyms in the given string fields of every record.\n",
        "\n",
        "        Input records are shallow-copied, not modified.\n",
        "\n",
        "        Returns:\n",
        "            ``(permuted_dataset, stats)`` where ``stats`` maps each field to the\n",
        "            number of records whose value for that field changed.\n",
        "        \"\"\"\n",
        "        if not dataset or not fields_to_permute:\n",
        "            return dataset, {}\n",
        "\n",
        "        permuted_dataset = []\n",
        "        replacement_stats = {field: 0 for field in fields_to_permute}\n",
        "\n",
        "        for record in dataset:\n",
        "            permuted_record = record.copy()\n",
        "\n",
        "            for field in fields_to_permute:\n",
        "                if field in record and isinstance(record[field], str):\n",
        "                    original_text = record[field]\n",
        "                    permuted_text = self._permute_text(original_text, permutation_rate)\n",
        "                    permuted_record[field] = permuted_text\n",
        "\n",
        "                    # Count records whose field actually changed\n",
        "                    if original_text != permuted_text:\n",
        "                        replacement_stats[field] += 1\n",
        "\n",
        "            permuted_dataset.append(permuted_record)\n",
        "\n",
        "        return permuted_dataset, replacement_stats\n",
        "\n",
        "    def _permute_text(self, text: str, permutation_rate: float) -> str:\n",
        "        \"\"\"Replace a random ``permutation_rate`` share of words (at least one) with synonyms.\n",
        "\n",
        "        Texts with fewer than two words, and texts in which no word was actually\n",
        "        replaced, are returned unchanged (original whitespace intact). Only words\n",
        "        whose alphanumeric core is longer than 3 characters are candidates;\n",
        "        surrounding punctuation and ALL-CAPS / Title case are preserved.\n",
        "        \"\"\"\n",
        "        words = text.split()\n",
        "        if len(words) < 2:  # Skip very short texts\n",
        "            return text\n",
        "\n",
        "        num_replacements = max(1, int(len(words) * permutation_rate))\n",
        "        words_to_replace = random.sample(range(len(words)), min(num_replacements, len(words)))\n",
        "\n",
        "        permuted_words = words.copy()\n",
        "        changed = False\n",
        "        for word_idx in words_to_replace:\n",
        "            word = words[word_idx]\n",
        "            # Strip punctuation but keep the original casing, so the core can be found\n",
        "            # inside `word` (a lower-cased core never matches e.g. \"Quick,\")\n",
        "            core = re.sub(r'[^\\w]', '', word)\n",
        "            clean_word = core.lower()\n",
        "\n",
        "            if len(clean_word) > 3:  # Only replace meaningful words\n",
        "                synonyms = self.get_synonyms(clean_word)\n",
        "                if synonyms:\n",
        "                    chosen_synonym = random.choice(synonyms)\n",
        "                    # Preserve original capitalization and punctuation\n",
        "                    if word.isupper():\n",
        "                        chosen_synonym = chosen_synonym.upper()\n",
        "                    elif word.istitle():\n",
        "                        chosen_synonym = chosen_synonym.title()\n",
        "\n",
        "                    replaced = word.replace(core, chosen_synonym, 1)\n",
        "                    if replaced != word:\n",
        "                        permuted_words[word_idx] = replaced\n",
        "                        changed = True\n",
        "\n",
        "        # Re-joining collapses whitespace, so only do it when something changed\n",
        "        return ' '.join(permuted_words) if changed else text\n",
        "\n",
        "    def get_permutation_preview(self, text: str, permutation_rate: float = 0.3) -> str:\n",
        "        \"\"\"Get a preview of how text would look after permutation (random each call).\"\"\"\n",
        "        return self._permute_text(text, permutation_rate)\n",
        "\n",
        "    def clear_cache(self):\n",
        "        \"\"\"Clear the synonym cache to free memory\"\"\"\n",
        "        self.synonym_cache.clear()\n",
        "\n",
        "# Initialize synonym permutator\n",
        "synonym_permutator = SynonymPermutator()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "cda75e7c",
      "metadata": {
        "id": "cda75e7c"
      },
      "outputs": [],
      "source": [
        "# Output & Export Module\n",
        "class DatasetExporter:\n",
        "    \"\"\"Handles dataset export to multiple formats (CSV, TSV, JSON, JSONL).\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        self.current_dataset = None\n",
        "        self.current_scores = None\n",
        "        self.export_history = []  # one entry per file written by save_dataset\n",
        "\n",
        "    def save_dataset(self, records: List[Dict], file_format: str, filename: str,\n",
        "                     output_dir: str = \"/content\") -> str:\n",
        "        \"\"\"Write ``records`` to a uniquely named file and return its path.\n",
        "\n",
        "        Args:\n",
        "            records: Rows to export.\n",
        "            file_format: One of \".csv\", \".tsv\", \".json\", \".jsonl\".\n",
        "            filename: Base file name; ``file_format`` is appended if missing.\n",
        "            output_dir: Directory for the file (created if needed). Defaults to\n",
        "                Colab's ``/content`` directory.\n",
        "\n",
        "        Returns:\n",
        "            The written file path (suitable for a Gradio File component), or None\n",
        "            when ``records`` is empty, the format is unsupported, or writing fails\n",
        "            (the error is printed).\n",
        "        \"\"\"\n",
        "        import os  # local import keeps this cell self-contained\n",
        "\n",
        "        if not records:\n",
        "            return None  # Return None to indicate no file\n",
        "\n",
        "        if file_format not in (\".csv\", \".tsv\", \".json\", \".jsonl\"):\n",
        "            return None\n",
        "\n",
        "        try:\n",
        "            # Ensure filename has correct extension\n",
        "            if not filename.endswith(file_format):\n",
        "                filename += file_format\n",
        "\n",
        "            # Drop only the trailing extension; str.replace would also strip matching\n",
        "            # text elsewhere in the name (e.g. \".json\" inside \"x.jsonl.json\")\n",
        "            base_name = filename[:-len(file_format)]\n",
        "\n",
        "            # Timestamp suffix gives each export a unique name to avoid caching issues\n",
        "            timestamp = int(time.time())\n",
        "            unique_filename = f\"{base_name}_{timestamp}{file_format}\"\n",
        "\n",
        "            os.makedirs(output_dir, exist_ok=True)\n",
        "            file_path = os.path.join(output_dir, unique_filename)\n",
        "\n",
        "            if file_format == \".jsonl\":\n",
        "                with open(file_path, 'w', encoding='utf-8') as f:\n",
        "                    for record in records:\n",
        "                        f.write(json.dumps(record) + '\\n')\n",
        "            else:\n",
        "                df = pd.DataFrame(records)\n",
        "                if file_format == \".csv\":\n",
        "                    df.to_csv(file_path, index=False)\n",
        "                elif file_format == \".tsv\":\n",
        "                    df.to_csv(file_path, sep=\"\\t\", index=False)\n",
        "                else:  # \".json\"\n",
        "                    df.to_json(file_path, orient=\"records\", indent=2)\n",
        "\n",
        "            self.export_history.append({\n",
        "                'file_path': file_path,\n",
        "                'file_format': file_format,\n",
        "                'num_records': len(records),\n",
        "                'timestamp': timestamp,\n",
        "            })\n",
        "            print(f\"File generated and saved at: {file_path}\")\n",
        "            return file_path\n",
        "\n",
        "        except Exception as e:\n",
        "            print(f\"Error saving dataset: {str(e)}\")\n",
        "            return None\n",
        "\n",
        "    def save_with_scores(self, records: List[Dict], scores: List[int], file_format: str, filename: str,\n",
        "                         output_dir: str = \"/content\") -> str:\n",
        "        \"\"\"Save ``records`` with an added ``quality_score`` column (see ``save_dataset``).\n",
        "\n",
        "        Scores are matched to records by position; records beyond the end of\n",
        "        ``scores`` get 0. The caller's record dicts are not modified.\n",
        "\n",
        "        Returns:\n",
        "            The written file path, or None if there is nothing to save or saving fails.\n",
        "        \"\"\"\n",
        "        if not records or not scores:\n",
        "            return None\n",
        "\n",
        "        try:\n",
        "            # Add scores to shallow copies of the records\n",
        "            records_with_scores = []\n",
        "            for i, record in enumerate(records):\n",
        "                record_with_score = record.copy()\n",
        "                record_with_score['quality_score'] = scores[i] if i < len(scores) else 0\n",
        "                records_with_scores.append(record_with_score)\n",
        "\n",
        "            return self.save_dataset(records_with_scores, file_format, filename, output_dir)\n",
        "\n",
        "        except Exception as e:\n",
        "            print(f\"Error saving dataset with scores: {str(e)}\")\n",
        "            return None\n",
        "\n",
        "    def export_quality_report(self, scores: List[int], dataset: List[Dict], filename: str) -> str:\n",
        "        \"\"\"Export the quality report (plus score statistics) as JSON to ``filename``.\n",
        "\n",
        "        Returns:\n",
        "            A status message starting with \"✅\" on success or \"❌\" on failure.\n",
        "        \"\"\"\n",
        "        import statistics  # local import keeps this cell self-contained\n",
        "\n",
        "        try:\n",
        "            if not scores or not dataset:\n",
        "                return \"❌ Error: No data to analyze\"\n",
        "\n",
        "            # Generate quality report\n",
        "            report = quality_scorer.generate_quality_report(scores, dataset)\n",
        "\n",
        "            report['export_timestamp'] = pd.Timestamp.now().isoformat()\n",
        "            report['dataset_size'] = len(dataset)\n",
        "            report['score_statistics'] = {\n",
        "                'mean': round(sum(scores) / len(scores), 2),\n",
        "                # statistics.median averages the two middle values for even-length\n",
        "                # input; indexing sorted(scores)[n // 2] returned the upper one\n",
        "                'median': round(statistics.median(scores), 2),\n",
        "                'std': round(pd.Series(scores).std(), 2)\n",
        "            }\n",
        "\n",
        "            # Save report\n",
        "            with open(filename, 'w') as f:\n",
        "                json.dump(report, f, indent=2)\n",
        "\n",
        "            return f\"✅ Quality report saved to {filename}\"\n",
        "\n",
        "        except Exception as e:\n",
        "            return f\"❌ Error saving quality report: {str(e)}\"\n",
        "\n",
        "    def create_preview_dataframe(self, records: List[Dict], num_rows: int = 20) -> pd.DataFrame:\n",
        "        \"\"\"Create a DataFrame of the first ``num_rows`` records for display.\"\"\"\n",
        "        if not records:\n",
        "            return pd.DataFrame()\n",
        "\n",
        "        df = pd.DataFrame(records)\n",
        "        return df.head(num_rows)\n",
        "\n",
        "    def get_dataset_summary(self, records: List[Dict]) -> Dict[str, Any]:\n",
        "        \"\"\"Get summary statistics for the dataset.\n",
        "\n",
        "        Note: ``data_types`` holds pandas dtype objects and ``memory_usage`` a\n",
        "        numpy integer, so the dict is not directly JSON-serialisable.\n",
        "        \"\"\"\n",
        "        if not records:\n",
        "            return {\"error\": \"No data available\"}\n",
        "\n",
        "        df = pd.DataFrame(records)\n",
        "\n",
        "        summary = {\n",
        "            'total_records': len(records),\n",
        "            'total_fields': len(df.columns),\n",
        "            'field_names': list(df.columns),\n",
        "            'data_types': df.dtypes.to_dict(),\n",
        "            'missing_values': df.isnull().sum().to_dict(),\n",
        "            'memory_usage': df.memory_usage(deep=True).sum(),\n",
        "            'sample_records': records[:3]  # First 3 records as sample\n",
        "        }\n",
        "\n",
        "        return summary\n",
        "\n",
        "    def get_export_history(self) -> List[Dict]:\n",
        "        \"\"\"Get a copy of the history of files written by save_dataset.\"\"\"\n",
        "        return self.export_history.copy()\n",
        "\n",
        "    def clear_history(self):\n",
        "        \"\"\"Clear export history\"\"\"\n",
        "        self.export_history.clear()\n",
        "\n",
        "# Initialize dataset exporter\n",
        "dataset_exporter = DatasetExporter()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "2a85481e",
      "metadata": {
        "id": "2a85481e"
      },
      "outputs": [],
      "source": [
        "# Global state shared by the Gradio callbacks below\n",
        "current_dataset, current_scores = [], []\n",
        "current_schema_text = DEFAULT_SCHEMA_TEXT\n",
        "current_business_case = \"Pharmacogenomics patient data for drug response analysis\"\n",
        "\n",
        "# Gradio UI Functions\n",
        "def generate_schema(business_case, schema_mode, schema_text, model_name, temperature):\n",
        "    \"\"\"Generate, enhance, or accept a schema and record it as the current schema.\n",
        "\n",
        "    Args:\n",
        "        business_case: Free-text description of the dataset's purpose.\n",
        "        schema_mode: \"LLM Generate\", \"LLM Enhance Manual\", or anything else\n",
        "            (treated as manual entry: ``schema_text`` is used as-is).\n",
        "        schema_text: User-entered schema (input to the enhance and manual modes).\n",
        "        model_name: Model used by the LLM modes.\n",
        "        temperature: Sampling temperature for the LLM modes.\n",
        "\n",
        "    Returns:\n",
        "        ``(schema, schema, schema, business_case)``.\n",
        "    \"\"\"\n",
        "    # Without this declaration the assignments below only create locals and the\n",
        "    # module-level state never changes.\n",
        "    global current_schema_text, current_business_case\n",
        "\n",
        "    if schema_mode == \"LLM Generate\":\n",
        "        result = schema_manager.generate_schema_with_llm(business_case, model_name, temperature)\n",
        "    elif schema_mode == \"LLM Enhance Manual\":\n",
        "        result = schema_manager.enhance_schema_with_llm(schema_text, business_case, model_name, temperature)\n",
        "    else:  # Manual Entry\n",
        "        result = schema_text\n",
        "\n",
        "    current_schema_text = result\n",
        "    current_business_case = business_case\n",
        "    return result, result, result, business_case\n",
        "\n",
        "def generate_dataset_ui(schema_text, business_case, model_name, temperature, num_records, examples):\n",
        "    \"\"\"Generate dataset using selected model\"\"\"\n",
        "    global current_dataset\n",
        "\n",
        "    status, records = dataset_generator.generate_dataset(\n",
        "        schema_text, business_case, model_name, temperature, num_records, examples\n",
        "    )\n",
        "\n",
        "    current_dataset = records\n",
        "    preview_df = dataset_exporter.create_preview_dataframe(records, 20)\n",
        "\n",
        "    return status, preview_df, len(records)\n",
        "\n",
        "def apply_synonym_permutation(enable_permutation, fields_to_permute, permutation_rate):\n",
        "    \"\"\"Apply synonym permutation to dataset - FIXED VERSION\"\"\"\n",
        "    global current_dataset\n",
        "\n",
        "    if not enable_permutation:\n",
        "        return current_dataset, \"❌ Permutation is disabled - check the 'Enable Synonym Permutation' checkbox\"\n",
        "\n",
        "    if not current_dataset:\n",
        "        return [], \"❌ No dataset available - generate a dataset first\"\n",
        "\n",
        "    if not fields_to_permute:\n",
        "        # Try to auto-identify fields if none are selected\n",
        "        try:\n",
        "            auto_fields = synonym_permutator.identify_text_fields(current_dataset)\n",
        "            if auto_fields:\n",
        "                fields_to_permute = auto_fields[:2]  # Use first 2 fields as default\n",
        "                print(f\"DEBUG: Auto-selected fields: {fields_to_permute}\")\n",
        "            else:\n",
        "                return current_dataset, \"❌ No text fields found for permutation\"\n",
        "        except Exception as e:\n",
        "            return current_dataset, f\"❌ Error identifying fields: {str(e)}\"\n",
        "\n",
        "    try:\n",
        "        permuted_dataset, stats = synonym_permutator.permute_with_synonyms(\n",
        "            current_dataset, fields_to_permute, permutation_rate / 100\n",
        "        )\n",
        "\n",
        "        current_dataset = permuted_dataset\n",
        "\n",
        "        # Convert to DataFrame for proper display\n",
        "        import pandas as pd\n",
        "        preview_df = pd.DataFrame(permuted_dataset)\n",
        "\n",
        "        stats_text = f\"✅ Permutation applied to {len(fields_to_permute)} fields. \"\n",
        "        stats_text += f\"Replacement counts: {stats}\"\n",
        "\n",
        "        return preview_df, stats_text\n",
        "\n",
        "    except Exception as e:\n",
        "        print(f\"DEBUG: Error during permutation: {str(e)}\")\n",
        "        return current_dataset, f\"❌ Error during permutation: {str(e)}\"\n",
        "\n",
        "def score_dataset_quality(scoring_model, scoring_temperature):\n",
        "    \"\"\"Score dataset quality using selected model\"\"\"\n",
        "    global current_dataset, current_scores\n",
        "\n",
        "    if not current_dataset:\n",
        "        return \"No dataset available for scoring\", [], {}\n",
        "\n",
        "    # Extract quality rules\n",
        "    original_prompt = f\"Business case: {current_business_case}\"\n",
        "    rules = quality_scorer.extract_quality_rules(original_prompt, current_schema_text)\n",
        "\n",
        "    # Score dataset\n",
        "    scores, stats = quality_scorer.score_dataset(current_dataset, scoring_model, scoring_temperature)\n",
        "    current_scores = scores\n",
        "\n",
        "    # Create scores DataFrame for display\n",
        "    scores_df = pd.DataFrame({\n",
        "        'Record_Index': range(len(scores)),\n",
        "        'Quality_Score': scores,\n",
        "        'Quality_Level': ['Excellent' if s >= 90 else 'Good' if s >= 70 else 'Fair' if s >= 50 else 'Poor' for s in scores]\n",
        "    })\n",
        "\n",
        "    # Generate report\n",
        "    report = quality_scorer.generate_quality_report(scores, current_dataset)\n",
        "\n",
        "    status = f\"✅ Scored {len(scores)} records. Average score: {stats['average_score']}\"\n",
        "\n",
        "    return status, scores_df, report\n",
        "\n",
        "def export_dataset(file_format, filename, include_scores):\n",
        "    \"\"\"Export dataset to specified format\"\"\"\n",
        "    global current_dataset, current_scores\n",
        "\n",
        "    if not current_dataset:\n",
        "        return \"No dataset to export\"\n",
        "\n",
        "    try:\n",
        "        if include_scores and current_scores:\n",
        "            result = dataset_exporter.save_with_scores(current_dataset, current_scores, file_format, filename)\n",
        "        else:\n",
        "            result = dataset_exporter.save_dataset(current_dataset, file_format, filename)\n",
        "        return result\n",
        "    except Exception as e:\n",
        "        return f\"❌ Error exporting dataset: {str(e)}\"\n",
        "\n",
        "def get_available_fields():\n",
        "    \"\"\"Get available fields for permutation\"\"\"\n",
        "    if not current_dataset:\n",
        "        return []\n",
        "\n",
        "    return synonym_permutator.identify_text_fields(current_dataset)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Graddle"
      ],
      "metadata": {
        "id": "fDerxxf1zfpu"
      },
      "id": "fDerxxf1zfpu"
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "ccc985a6",
      "metadata": {
        "id": "ccc985a6"
      },
      "outputs": [],
      "source": [
        "# Create Gradio Interface\n",
        "def create_gradio_interface():\n",
        "    \"\"\"Create the main Gradio interface with 5 tabs\"\"\"\n",
        "\n",
        "    # Combine all models for dropdowns\n",
        "    all_models = list(COMMERCIAL_MODELS.keys())+list(HUGGINGFACE_MODELS.keys())\n",
        "\n",
        "    with gr.Blocks(title=\"Synthetic Dataset Generator\", theme=gr.themes.Soft()) as interface:\n",
        "\n",
        "        gr.Markdown(\"# Synthetic Dataset Generator with Quality Scoring\")\n",
        "        gr.Markdown(\"Generate realistic synthetic datasets using multiple LLM models with flexible schema creation, synonym permutation, and automated quality scoring.\")\n",
        "\n",
        "        # Status bar\n",
        "        with gr.Row():\n",
        "            gpu_status = gr.Textbox(\n",
        "                label=\"GPU Status\",\n",
        "                value=dataset_generator.get_memory_usage(),\n",
        "                interactive=False,\n",
        "                scale=1\n",
        "            )\n",
        "            current_status = gr.Textbox(\n",
        "                label=\"Current Status\",\n",
        "                value=\"Ready to generate datasets\",\n",
        "                interactive=False,\n",
        "                scale=2\n",
        "            )\n",
        "\n",
        "        # Tab 1: Schema Definition\n",
        "        with gr.Tab(\"📋 Schema Definition\"):\n",
        "            gr.Markdown(\"### Define your dataset schema\")\n",
        "\n",
        "            with gr.Row():\n",
        "                with gr.Column(scale=2):\n",
        "                    schema_mode = gr.Radio(\n",
        "                        choices=[\"LLM Generate\", \"Manual Entry\", \"LLM Enhance Manual\"],\n",
        "                        value=\"Manual Entry\",\n",
        "                        label=\"Schema Mode\"\n",
        "                    )\n",
        "\n",
        "                    business_case_input = gr.Textbox(\n",
        "                        label=\"Business Case\",\n",
        "                        value=current_business_case,\n",
        "                        lines=3,\n",
        "                        placeholder=\"Describe your business case or data requirements...\"\n",
        "                    )\n",
        "\n",
        "                    schema_input = gr.Textbox(\n",
        "                        label=\"Schema Definition\",\n",
        "                        value=current_schema_text,\n",
        "                        lines=15,\n",
        "                        placeholder=\"Define your dataset schema here...\"\n",
        "                    )\n",
        "\n",
        "                    with gr.Row():\n",
        "                        schema_model = gr.Dropdown(\n",
        "                            choices=all_models,\n",
        "                            value=all_models[0],\n",
        "                            label=\"Model for Schema Generation\"\n",
        "                        )\n",
        "                        schema_temperature = gr.Slider(\n",
        "                            minimum=0.0,\n",
        "                            maximum=2.0,\n",
        "                            value=0.7,\n",
        "                            step=0.1,\n",
        "                            label=\"Temperature\"\n",
        "                        )\n",
        "\n",
        "                    generate_schema_btn = gr.Button(\"🔄 Generate/Enhance Schema\", variant=\"primary\")\n",
        "\n",
        "                with gr.Column(scale=1):\n",
        "                    schema_output = gr.Textbox(\n",
        "                        label=\"Generated Schema\",\n",
        "                        lines=15,\n",
        "                        interactive=False\n",
        "                    )\n",
        "\n",
        "        # Tab 2: Dataset Generation\n",
        "        with gr.Tab(\"🚀 Dataset Generation\"):\n",
        "            gr.Markdown(\"### Generate synthetic dataset\")\n",
        "\n",
        "            with gr.Row():\n",
        "                with gr.Column(scale=2):\n",
        "                    generation_schema = gr.Textbox(\n",
        "                        label=\"Schema (from Tab 1)\",\n",
        "                        value=current_schema_text,\n",
        "                        lines=8,\n",
        "                        interactive=False\n",
        "                    )\n",
        "\n",
        "                    generation_business_case = gr.Textbox(\n",
        "                        label=\"Business Case\",\n",
        "                        value=current_business_case,\n",
        "                        lines=2\n",
        "                    )\n",
        "\n",
        "                    examples_input = gr.Textbox(\n",
        "                        label=\"Few-shot Examples (JSON format)\",\n",
        "                        lines=5,\n",
        "                        placeholder='[{\"instruction\": \"example\", \"response\": \"example\"}]',\n",
        "                        value=\"\"\n",
        "                    )\n",
        "\n",
        "                    with gr.Row():\n",
        "                        generation_model = gr.Dropdown(\n",
        "                            choices=all_models,\n",
        "                            value=all_models[0],\n",
        "                            label=\"Generation Model\"\n",
        "                        )\n",
        "                        generation_temperature = gr.Slider(\n",
        "                            minimum=0.0,\n",
        "                            maximum=2.0,\n",
        "                            value=0.7,\n",
        "                            step=0.1,\n",
        "                            label=\"Temperature\"\n",
        "                        )\n",
        "                        num_records = gr.Number(\n",
        "                            value=50,\n",
        "                            minimum=11,\n",
        "                            maximum=1000,\n",
        "                            step=1,\n",
        "                            label=\"Number of Records\"\n",
        "                        )\n",
        "\n",
        "                    generate_dataset_btn = gr.Button(\"🚀 Generate Dataset\", variant=\"primary\", size=\"lg\")\n",
        "\n",
        "                with gr.Column(scale=1):\n",
        "                    generation_status = gr.Textbox(\n",
        "                        label=\"Generation Status\",\n",
        "                        lines=3,\n",
        "                        interactive=False\n",
        "                    )\n",
        "\n",
        "                    dataset_preview = gr.Dataframe(\n",
        "                        label=\"Dataset Preview (First 20 rows)\",\n",
        "                        interactive=False,\n",
        "                        wrap=True\n",
        "                    )\n",
        "\n",
        "                    record_count = gr.Number(\n",
        "                        label=\"Total Records Generated\",\n",
        "                        interactive=False\n",
        "                    )\n",
        "\n",
        "        # Tab 3: Synonym Permutation\n",
        "        with gr.Tab(\"🔄 Synonym Permutation\"):\n",
        "            gr.Markdown(\"### Add diversity with synonym replacement\")\n",
        "\n",
        "            with gr.Row():\n",
        "                with gr.Column(scale=2):\n",
        "                    enable_permutation = gr.Checkbox(\n",
        "                        label=\"Enable Synonym Permutation\",\n",
        "                        value=False\n",
        "                    )\n",
        "\n",
        "                    fields_to_permute = gr.CheckboxGroup(\n",
        "                        label=\"Fields to Permute\",\n",
        "                        choices=[],\n",
        "                        value=[]\n",
        "                    )\n",
        "\n",
        "                    permutation_rate = gr.Slider(\n",
        "                        minimum=0,\n",
        "                        maximum=50,\n",
        "                        value=20,\n",
        "                        step=5,\n",
        "                        label=\"Permutation Rate (%)\"\n",
        "                    )\n",
        "\n",
        "                    apply_permutation_btn = gr.Button(\"🔄 Apply Permutation\", variant=\"secondary\")\n",
        "\n",
        "                with gr.Column(scale=1):\n",
        "                    permutation_status = gr.Textbox(\n",
        "                        label=\"Permutation Status\",\n",
        "                        lines=2,\n",
        "                        interactive=False\n",
        "                    )\n",
        "\n",
        "                permuted_preview = gr.Dataframe(\n",
        "                    label=\"Permuted Dataset Preview\",\n",
        "                    interactive=False,\n",
        "                    wrap=True,\n",
        "                    datatype=[\"str\"] * 10\n",
        "                )\n",
        "\n",
        "        # Tab 4: Quality Scoring\n",
        "        with gr.Tab(\"📊 Quality Scoring\"):\n",
        "            gr.Markdown(\"### Evaluate dataset quality\")\n",
        "\n",
        "            with gr.Row():\n",
        "                with gr.Column(scale=2):\n",
        "                    scoring_model = gr.Dropdown(\n",
        "                        choices=all_models,\n",
        "                        value=all_models[0],\n",
        "                        label=\"Scoring Model\"\n",
        "                    )\n",
        "\n",
        "                    scoring_temperature = gr.Slider(\n",
        "                        minimum=0.0,\n",
        "                        maximum=2.0,\n",
        "                        value=0.3,\n",
        "                        step=0.1,\n",
        "                        label=\"Temperature\"\n",
        "                    )\n",
        "\n",
        "                    score_dataset_btn = gr.Button(\"📊 Score Dataset Quality\", variant=\"primary\")\n",
        "\n",
        "                with gr.Column(scale=1):\n",
        "                    scoring_status = gr.Textbox(\n",
        "                        label=\"Scoring Status\",\n",
        "                        lines=2,\n",
        "                        interactive=False\n",
        "                    )\n",
        "\n",
        "                    scores_dataframe = gr.Dataframe(\n",
        "                        label=\"Quality Scores\",\n",
        "                        interactive=False\n",
        "                    )\n",
        "\n",
        "                    quality_report = gr.JSON(\n",
        "                        label=\"Quality Report\"\n",
        "                    )\n",
        "\n",
        "        with gr.Tab(\"💾 Export\"):\n",
        "            gr.Markdown(\"### Export your dataset\")\n",
        "\n",
        "            with gr.Row():\n",
        "                with gr.Column(scale=2):\n",
        "                    file_format = gr.Dropdown(\n",
        "                        choices=OUTPUT_FORMATS,\n",
        "                        value=\".csv\",\n",
        "                        label=\"File Format\"\n",
        "                    )\n",
        "\n",
        "                    filename = gr.Textbox(\n",
        "                        label=\"Filename\",\n",
        "                        value=\"synthetic_dataset\",\n",
        "                        placeholder=\"Enter filename (extension added automatically)\"\n",
        "                    )\n",
        "\n",
        "                    include_scores = gr.Checkbox(\n",
        "                        label=\"Include Quality Scores\",\n",
        "                        value=False\n",
        "                    )\n",
        "\n",
        "                    export_btn = gr.Button(\"💾 Export Dataset\", variant=\"primary\")\n",
        "\n",
        "                with gr.Column(scale=1):\n",
        "                    # Use gr.File component for download\n",
        "                    download_file = gr.File(\n",
        "                        label=\"Download your file here\",\n",
        "                        interactive=False,\n",
        "                        visible=True\n",
        "                    )\n",
        "\n",
        "                    export_status = gr.Textbox(\n",
        "                        label=\"Export Status\",\n",
        "                        lines=3,\n",
        "                        interactive=False\n",
        "                    )\n",
        "\n",
        "        # Event handlers\n",
        "        generate_schema_btn.click(\n",
        "            generate_schema,\n",
        "            inputs=[business_case_input, schema_mode, schema_input, schema_model, schema_temperature],\n",
        "            outputs=[schema_output, schema_input, generation_schema, generation_business_case]\n",
        "        )\n",
        "\n",
        "        generate_dataset_btn.click(\n",
        "            generate_dataset_ui,\n",
        "            inputs=[generation_schema, generation_business_case, generation_model, generation_temperature, num_records, examples_input],\n",
        "            outputs=[generation_status, dataset_preview, record_count]\n",
        "        )\n",
        "\n",
        "        apply_permutation_btn.click(\n",
        "            apply_synonym_permutation,\n",
        "            inputs=[enable_permutation, fields_to_permute, permutation_rate],\n",
        "            outputs=[permuted_preview, permutation_status]\n",
        "        )\n",
        "\n",
        "        score_dataset_btn.click(\n",
        "            score_dataset_quality,\n",
        "            inputs=[scoring_model, scoring_temperature],\n",
        "            outputs=[scoring_status, scores_dataframe, quality_report]\n",
        "        )\n",
        "\n",
        "\n",
        "        def export_dataset_with_file(file_format, filename, include_scores):\n",
        "              \"\"\"Export dataset with file download\"\"\"\n",
        "              global current_dataset, current_scores\n",
        "\n",
        "              if not current_dataset:\n",
        "                  return None, \"❌ No dataset to export\"\n",
        "\n",
        "              try:\n",
        "                  if include_scores and current_scores:\n",
        "                      file_path = dataset_exporter.save_with_scores(current_dataset, current_scores, file_format, filename)\n",
        "                  else:\n",
        "                      file_path = dataset_exporter.save_dataset(current_dataset, file_format, filename)\n",
        "\n",
        "                  if file_path:\n",
        "                      return file_path, f\"✅ Dataset ready for download: {filename}\"\n",
        "                  else:\n",
        "                      return None, \"❌ Error creating file\"\n",
        "\n",
        "              except Exception as e:\n",
        "                  return None, f\"❌ Error exporting dataset: {str(e)}\"\n",
        "\n",
        "        export_btn.click(\n",
        "            export_dataset_with_file,\n",
        "            inputs=[file_format, filename, include_scores],\n",
        "            outputs=[download_file, export_status]\n",
        "        )\n",
        "\n",
        "        def update_field_choices():\n",
        "            \"\"\"Update field choices when dataset is generated - FIXED VERSION\"\"\"\n",
        "            global current_dataset\n",
        "\n",
        "            if not current_dataset:\n",
        "                print(\"DEBUG: No current dataset available\")\n",
        "                return gr.CheckboxGroup(choices=[], value=[])\n",
        "\n",
        "            try:\n",
        "                fields = synonym_permutator.identify_text_fields(current_dataset)\n",
        "                print(f\"DEBUG: Available fields for permutation: {fields}\")\n",
        "\n",
        "                if not fields:\n",
        "                    print(\"DEBUG: No text fields identified\")\n",
        "                    return gr.CheckboxGroup(choices=[], value=[])\n",
        "\n",
        "                return gr.CheckboxGroup(choices=fields, value=[])\n",
        "            except Exception as e:\n",
        "                print(f\"DEBUG: Error identifying fields: {str(e)}\")\n",
        "                return gr.CheckboxGroup(choices=[], value=[])\n",
        "\n",
        "        # Auto-update field choices\n",
        "        generate_dataset_btn.click(\n",
        "            generate_dataset_ui,\n",
        "            inputs=[generation_schema, generation_business_case, generation_model, generation_temperature, num_records, examples_input],\n",
        "            outputs=[generation_status, dataset_preview, record_count]\n",
        "        ).then(\n",
        "            update_field_choices,  # This should run after dataset generation\n",
        "            outputs=[fields_to_permute]\n",
        "        )\n",
        "\n",
        "    return interface\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "70d39131",
      "metadata": {
        "id": "70d39131"
      },
      "outputs": [],
      "source": [
        "# Launch the Gradio Interface\n",
        "interface = create_gradio_interface()\n",
        "interface.launch(debug=True, share=True)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "212aa78a",
      "metadata": {
        "id": "212aa78a"
      },
      "source": [
        "## Example Workflow: Dataset\n",
        "\n",
        "This section demonstrates the complete pipeline using a pharmacogenomics (PGx) example.\n",
        "\n",
        "### Step 1: Schema Definition\n",
        "The default schema is already configured for pharmacogenomics data, including:\n",
        "- Patient demographics (age, gender, ethnicity)\n",
        "- Genetic variants (CYP2D6, CYP2C19, etc.)\n",
        "- Drug information (name, dosage)\n",
        "- Clinical outcomes (efficacy, adverse reactions)\n",
        "- Metabolizer status\n",
        "\n",
        "### Step 2: Dataset Generation\n",
        "1. Select a model (recommended: Llama 3.1 8B for quality, Llama 3.2 3B for speed)\n",
        "2. Set temperature (0.7 for balanced creativity/consistency)\n",
        "3. Specify number of records (50-100 for testing, 500+ for production)\n",
        "4. Add few-shot examples if needed\n",
        "\n",
        "### Step 3: Synonym Permutation\n",
        "1. Enable permutation checkbox\n",
        "2. Select text fields (e.g., drug_name, adverse_reaction)\n",
        "3. Set permutation rate (20-30% recommended)\n",
        "4. Apply to increase diversity\n",
        "\n",
        "### Step 4: Quality Scoring\n",
        "1. Select scoring model (can be different from generation model)\n",
        "2. Use lower temperature (0.3) for consistent scoring\n",
        "3. Review quality report and flagged records\n",
        "4. Regenerate if quality is insufficient\n",
        "\n",
        "### Step 5: Export\n",
        "1. Choose format (CSV for analysis, JSON for APIs)\n",
        "2. Include quality scores if needed\n",
        "3. Download your dataset\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "9789613e",
      "metadata": {
        "id": "9789613e"
      },
      "outputs": [],
      "source": [
        "# Testing and Validation Functions\n",
        "def test_schema_generation():\n",
        "    \"\"\"Test schema generation functionality\"\"\"\n",
        "    print(\"🧪 Testing Schema Generation...\")\n",
        "\n",
        "    # Test manual schema parsing\n",
        "    test_schema = \"\"\"1. patient_id (TEXT) - Unique patient identifier, example: PGX_001\n",
        "2. age (INT) - Patient age in years, example: 45\n",
        "3. drug_name (TEXT) - Medication name, example: Warfarin\"\"\"\n",
        "\n",
        "    parsed = schema_manager.parse_manual_schema(test_schema)\n",
        "    print(f\"✅ Manual schema parsing: {len(parsed)} fields\")\n",
        "\n",
        "    # Test commercial API schema generation\n",
        "    if \"openai\" in clients:\n",
        "        print(\"🔄 Testing OpenAI schema generation...\")\n",
        "        result = schema_manager.generate_schema_with_llm(\n",
        "            \"Generate a dataset for e-commerce customer analysis\",\n",
        "            \"Phi-3.5 Mini\",\n",
        "            1\n",
        "        )\n",
        "        print(f\"✅ OpenAI schema generation: {len(result)} characters\")\n",
        "\n",
        "    return True\n",
        "\n",
        "def test_dataset_generation():\n",
        "    \"\"\"Test dataset generation with small sample\"\"\"\n",
        "    print(\"🧪 Testing Dataset Generation...\")\n",
        "\n",
        "    # Use a simple schema for testing\n",
        "    test_schema = \"\"\"1. name (TEXT) - Customer name, example: John Doe\n",
        "2. age (INT) - Customer age, example: 30\n",
        "3. purchase_amount (FLOAT) - Purchase amount, example: 99.99\"\"\"\n",
        "\n",
        "    business_case = \"Generate customer purchase data for a retail store\"\n",
        "\n",
        "    # Test with commercial API if available\n",
        "    if \"openai\" in clients:\n",
        "        print(\"🔄 Testing OpenAI dataset generation...\")\n",
        "        status, records = dataset_generator.generate_dataset(\n",
        "            test_schema, business_case, \"GPT-5 Mini\", 1, 5, \"\"\n",
        "        )\n",
        "        print(f\"✅ OpenAI generation: {status}\")\n",
        "        if records:\n",
        "            print(f\"   Generated {len(records)} records\")\n",
        "\n",
        "    return True\n",
        "\n",
        "def test_synonym_permutation():\n",
        "    \"\"\"Test synonym permutation functionality\"\"\"\n",
        "    print(\"🧪 Testing Synonym Permutation...\")\n",
        "\n",
        "    # Test synonym lookup\n",
        "    test_word = \"excellent\"\n",
        "    synonyms = synonym_permutator.get_synonyms(test_word)\n",
        "    print(f\"✅ Synonym lookup for '{test_word}': {len(synonyms)} synonyms found\")\n",
        "\n",
        "    # Test text permutation\n",
        "    test_text = \"The patient showed excellent response to treatment\"\n",
        "    permuted = synonym_permutator.get_permutation_preview(test_text, 0.3)\n",
        "    print(f\"✅ Text permutation: '{test_text}' -> '{permuted}'\")\n",
        "\n",
        "    return True\n",
        "\n",
        "def test_quality_scoring():\n",
        "    \"\"\"Test quality scoring functionality\"\"\"\n",
        "    print(\"🧪 Testing Quality Scoring...\")\n",
        "\n",
        "    # Create test record\n",
        "    test_record = {\n",
        "        \"patient_id\": \"TEST_001\",\n",
        "        \"age\": 45,\n",
        "        \"drug_name\": \"Warfarin\",\n",
        "        \"efficacy_score\": 8\n",
        "    }\n",
        "\n",
        "    # Test quality rules extraction\n",
        "    rules = quality_scorer.extract_quality_rules(\n",
        "        \"Test business case\",\n",
        "        \"1. patient_id (TEXT) - Patient ID, example: P001\"\n",
        "    )\n",
        "    print(f\"✅ Quality rules extraction: {len(rules)} characters\")\n",
        "\n",
        "    return True\n",
        "\n",
        "def run_integration_test():\n",
        "    \"\"\"Run complete integration test\"\"\"\n",
        "    print(\"🚀 Running Integration Tests...\")\n",
        "    print(\"=\" * 50)\n",
        "\n",
        "    try:\n",
        "        test_schema_generation()\n",
        "        print()\n",
        "\n",
        "        test_dataset_generation()\n",
        "        print()\n",
        "\n",
        "        test_synonym_permutation()\n",
        "        print()\n",
        "\n",
        "        test_quality_scoring()\n",
        "        print()\n",
        "\n",
        "        print(\"✅ All integration tests passed!\")\n",
        "        return True\n",
        "\n",
        "    except Exception as e:\n",
        "        print(f\"❌ Integration test failed: {str(e)}\")\n",
        "        return False\n",
        "\n",
        "# Run integration tests\n",
        "run_integration_test()\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.10"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}